# -*- coding: utf-8 -*-
"""
将数据集划分为训练集、验证集和测试集，比例为 8:1:1（训练集 8，验证集 1，测试集 1）。

"""

import random
import argparse
import os


parser = argparse.ArgumentParser()
# xml文件的地址，根据自己的数据进行修改 xml一般存放在Annotations下
parser.add_argument('--ana_path', default=r"E:\Bdd_yolov5\data\nightday_label", type=str, help='input xml label path')
# 数据集的划分，地址选择自己数据下的ImageSets/Main
parser.add_argument('--txt_path', default=r"E:\Bdd_yolov5\data\ImageSets\Main1", type=str, help='output txt label path')
opt = parser.parse_args()

# 创建ImageSets文件夹，并生成test.txt,train.txt,trainval.txt,val.txt文件

train_percent = 0.8  # 训练集所占比例
val_percent = 0.1  # 验证集所占比例
test_persent = 0.1  # 测试集所占比例

anafilepath = opt.ana_path
txtsavepath = opt.txt_path
total_ana = os.listdir(anafilepath)

if not os.path.exists(txtsavepath):
    os.makedirs(txtsavepath)

num = len(total_ana)
list = list(range(num))

t_train = int(num * train_percent)
t_val = int(num * val_percent)

train = random.sample(list, t_train)
num1 = len(train)
for i in range(num1):
    list.remove(train[i])

val_test = [i for i in list if not i in train]
val = random.sample(val_test, t_val)
num2 = len(val)
for i in range(num2):
    list.remove(val[i])

file_train = open(txtsavepath + '/train.txt', 'w')
file_val = open(txtsavepath + '/val.txt', 'w')
file_test = open(txtsavepath + '/test.txt', 'w')


for i in train:

    name = 'E:/Bdd_yolov5/data/JPEGImages/'+total_ana[i][:-4] +'.jpg' '\n'
    file_train.write(name)

for i in val:

    name = 'E:/Bdd_yolov5/data/JPEGImages/'+total_ana[i][:-4] +'.jpg' '\n'
    file_val.write(name)

for i in list:

    name = 'E:/Bdd_yolov5/data/JPEGImages/'+total_ana[i][:-4] +'.jpg' '\n'
    file_test.write(name)

file_train.close()
file_val.close()
file_test.close()
